import pandas as pd
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LinearRegression
import warnings
import numpy as np
import tkinter as tk
# Module-level side effect: installs an 'ignore' filter for every warning
# category as soon as this module is imported, so warnings raised by any
# library are silenced from this point on.
warnings.filterwarnings('ignore')



def predict(stock_codes, result_text):
    """Fit three regressors per stock and report next-day close predictions.

    For every code in ``stock_codes`` the file ``data/<code>.csv`` is read.
    The first 80 % of its rows (in file order) train each model, the
    remaining rows are used to compute the root-mean-squared error, and each
    fitted model then predicts a close price from the open/high/low/volume
    values of the last row.

    Args:
        stock_codes: Iterable of stock codes; each one names a CSV file under
            ``data/`` containing the stock-name, date, open, close, high, low
            and volume columns read below.
        result_text: Object whose ``insert(index, text)`` method receives the
            per-model report; it is always called with ``tk.END`` as index.

    Returns:
        None. Per-model results are written to ``result_text``; the stock
        name, the last feature row and the averaged prediction are printed
        to stdout.

    Raises:
        ValueError: If a CSV has too few rows to leave a non-empty training
            set after the 80/20 split.
    """
    # Feature columns: open, high, low, volume.
    feature_cols = ['开盘', '最高', '最低', '成交量']

    # Support-vector regression parameters: kernel type, regularisation
    # strength C and kernel coefficient gamma.
    svr_params = {'kernel': 'rbf', 'C': 1e3, 'gamma': 0.1}
    # Random-forest parameters: number of trees, split criterion and maximum
    # tree depth.
    rf_params = {'n_estimators': 10, 'criterion': 'squared_error', 'max_depth': None}

    files = [f'data/{code}.csv' for code in stock_codes]
    # The same estimator objects are re-fitted for every file.
    models = [SVR(**svr_params), RandomForestRegressor(**rf_params), LinearRegression()]

    for file in files:
        # Load only the columns that are actually used.
        df = pd.read_csv(file, usecols=['股票名称','日期', '开盘', '收盘', '最高', '最低', '成交量'])

        # Split into features and target (close price).
        X = df[feature_cols]
        y = df['收盘']

        # Ordered 80/20 split: earlier rows train, later rows test.
        train_size = int(len(X) * 0.8)
        if train_size == 0:
            # Fail with a message that names the file instead of an unrelated
            # KeyError / estimator error further down.
            raise ValueError(
                f'{file}: need at least 2 rows to split into train and test sets, got {len(X)}'
            )
        X_train, X_test = X.iloc[:train_size], X.iloc[train_size:]
        y_train, y_test = y.iloc[:train_size], y.iloc[train_size:]

        # One-row DataFrame so the prediction input carries the same column
        # names the models were fitted with.
        latest_features = X.iloc[[-1]]

        print('-' * 40)
        # Positional access: does not depend on the index containing label 0.
        print(df['股票名称'].iloc[0])
        print(X.iloc[-1])

        prediction_sum = 0.0
        # Train every model, report its test error and its next prediction.
        for model in models:
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            # Root-mean-squared error. Taking the square root explicitly
            # avoids the `squared=False` keyword, which newer scikit-learn
            # releases no longer accept.
            loss = np.sqrt(mean_squared_error(y_test, y_pred))
            print('-' * 40)
            result_text.insert(tk.END, f'{model.__class__.__name__}模型预测结果：\n')
            result_text.insert(tk.END, f'均方根误差： {loss:.4f}\n')
            # Predict the next close from the most recent feature row.
            y_pred_next = model.predict(latest_features)
            # Accumulate a plain scalar so the average is a number rather
            # than a one-element array.
            prediction_sum += float(y_pred_next[0])
            result_text.insert(tk.END, f'根据前一天的开盘值、最高值、最低值、成交量预测第二天的收盘值为: {y_pred_next[0]:.2f}\n')
        result_text.insert(tk.END, '-' * 40 + '\n')
        # Average over however many models are configured.
        avg = prediction_sum / len(models)
        print('根据三种机器学习算法预测的平均值得出第二天的收盘值为:', np.around(avg, decimals=2))
        result_text.insert(tk.END, '-' * 40 + '\n')
        result_text.insert(tk.END, '\n\n')